# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# pyre-strict

import copy
import json
import logging
import re
from pathlib import Path

from typing import Any, cast, Dict, List, Optional, Union

from ..benchmark import Benchmark, BenchmarkConfig

from ..benchmark_utils import extract_json
from ..llm import LLM
from ..query_llm import query_llm_to_generate_responses

# Guided-decoding schema: forces model output to be a JSON object with a
# "correct_answers" array of strings.
JSON_SCHEMA = '{"type": "object", "properties": {"correct_answers": {"type": "array", "items": {"type":"string"}}}, "required": ["correct_answers"]}'
# Number of leading characters kept from each signature "description" when
# condensing detonation reports before prompting.
SIGNATURE_DESCRIPTION_LEN = 50


def remove_hash_values(report: Dict[str, Any]) -> Dict[str, Any]:
    """Replace long hexadecimal digests in *report* with the token ``"hash"``.

    Any string value (at any nesting depth) containing a run of 32 or more
    lowercase hex characters has that run substituted with ``"hash"``.
    The report is mutated in place and also returned for convenience.
    """
    digest_re = re.compile(r"[0-9a-f]{32,}")

    def scrub(node: Union[Dict[str, object], List[object]]) -> None:
        # Uniform traversal: key/value pairs for dicts, index/value for lists.
        entries = node.items() if isinstance(node, dict) else enumerate(node)
        for key, value in list(entries):
            if isinstance(value, str):
                node[key] = digest_re.sub("hash", value)
            elif isinstance(value, (dict, list)):
                scrub(value)

    scrub(report)
    return report


class MalwareAnalysisBenchmark(Benchmark):
    """
    This class implements the malware analysis data reasoning SOC benchmark.

    Each test case pairs a sandbox ("detonation") report with a multiple-choice
    question. The benchmark prompts each model under test for the correct
    answer letters, grades responses by Jaccard similarity against the
    ground-truth options, and aggregates stats per model as well as per
    topic, difficulty, and attack type.
    """

    # Path to the test-case file; populated by _validate_prompt_path().
    test_cases_path: Optional[Path] = None

    def __init__(
        self,
        config: BenchmarkConfig,
    ) -> None:
        super().__init__(config)
        # When True, detonation reports are condensed (see _truncate_report)
        # before being embedded in prompts, to reduce context size.
        self.truncate_input: bool = config.truncate_input

    @classmethod
    def return_kind(cls) -> list[str]:
        """Return the benchmark kind identifiers handled by this class."""
        return ["malware_analysis"]

    def _validate_prompt_path(self, prompt_path: Path) -> None:
        """Validate that the prompt path exists and is a file.

        Raises:
            ValueError: If the path does not exist or is not a regular file.
        """

        if not (prompt_path.exists() and prompt_path.is_file()):
            raise ValueError(
                f"Test cases file: {prompt_path} does not exist or is not a file."
            )

        self.test_cases_path: Path = prompt_path

    def _get_report(self, sha256: str, attack_type: str) -> Dict[str, Any]:
        """Get the report for the given sha256 of given attack type.

        Reports are read from
        ``<test_cases_dir>/hybrid-analysis/<attack_type>/<sha256>``.

        Raises:
            RuntimeError: If no report file exists for the sample.
        """
        assert self.test_cases_path is not None
        report_path = (
            self.test_cases_path.parent / "hybrid-analysis" / attack_type / f"{sha256}"
        )
        if not report_path.exists():
            raise RuntimeError(
                f"Report file not found at {report_path}. Unable to retrieve the report from CrowdStrike."
            )
        with open(report_path, "r") as f:
            return json.load(f)

    def _truncate_report(self, report: Dict[str, Any]) -> Dict[str, Any]:
        """
        Retains only the important keys from the given JSON report.
        The following keys are eliminated:
            - uid: This key is not present in the top-level dictionary, but rather inside the processes list.
            Since we are only retaining top-level keys, this key was not included.
            - name, normalized_path, av_label, av_matched, and av_total: These keys are also present inside
                the processes list and were not retained for the same reason as above.
            - tactic, technique, attck_id, attck_id_wiki, suspicious_identifiers_count, informative_identifiers_count, parent:
                These keys are present inside the mitre_attcks list and were not retained for the same reason as above.
            - threat_level_human, category, identifier, description, origin, attck_id, and attck_id_wiki:
                These keys are present inside the signatures list and were not retained for the same reason as above.
        Args:
            report (Dict[str, Any]): The parsed JSON report.
        Returns:
            dict: A dictionary containing only the important keys, with hashes
            masked, signature descriptions truncated, and mitre_attcks condensed.
        """
        # Define the important keys to retain
        important_keys = [
            "size",
            "type",
            "submit_name",
            "sha256",
            "av_detect",
            "vx_family",
            "threat_score",
            "threat_level",
            "verdict",
            "certificates_validation_message",
            "total_processes",
            "total_signatures",
            "file_metadata",
            "processes",
            "mitre_attcks",
            "network_mode",
            "signatures",
        ]
        # Create a new dictionary with only the important keys
        truncated_report = {key: report[key] for key in important_keys if key in report}

        # Remove hashes from the report
        truncated_report = remove_hash_values(truncated_report)

        # Truncate each signature description to SIGNATURE_DESCRIPTION_LEN chars
        signatures = truncated_report.get("signatures", [])
        condensed_signatures = []

        for signature in signatures:
            if "description" in signature:
                signature["description"] = signature["description"][
                    :SIGNATURE_DESCRIPTION_LEN
                ]
            condensed_signatures.append(signature)
        truncated_report["signatures"] = condensed_signatures

        # Retain only specific keys in the mitre_attcks list
        mitre_attcks = truncated_report.get("mitre_attcks", [])
        condensed_mitre_attcks = []

        for attck in mitre_attcks:
            condensed_attck = {
                key: attck[key]
                # Fix: the key is "attck_id". The previous literal "attck\_id"
                # contained a backslash (invalid escape), so the id was never
                # retained in the condensed report.
                for key in ["tactic", "technique", "attck_id"]
                if key in attck
            }
            condensed_mitre_attcks.append(condensed_attck)
        truncated_report["mitre_attcks"] = condensed_mitre_attcks

        return truncated_report

    def _create_prompt(
        self, llm: LLM, prompt_id: int, test_case: Dict[str, Union[int, str]]
    ) -> str:
        """Create a prompt for the given test case.

        Embeds the (optionally truncated) detonation report, the question, and
        the answer options, and instructs the model to answer in JSON.
        (``llm`` and ``prompt_id`` are unused but kept for signature parity.)
        """
        sha256 = cast(str, test_case.get("sha256"))
        attack_type = cast(str, test_case.get("attack"))
        report = self._get_report(sha256, attack_type)
        question = test_case.get("question")
        options = test_case.get("options")

        truncated_report = copy.deepcopy(report)
        if report and self.truncate_input:
            truncated_report = self._truncate_report(report)
            logging.info(
                f"orig chars: {len(str(report))}, trunc chars: {len(str(truncated_report))}"
            )

        prompt = f"""Given this detonation report: {truncated_report}.

        Answer the following multi-choice question: {question}.

        Options: {options}.

        You need to return the list of correct answers. Respond in a json with the following structure:
        {{
            "correct_answers": string[] // The list of the letters corresponding to the correct answers, just the letters
        }}
        """

        return prompt

    def _process_prompt_for_test_case(
        self,
        prompt_id: int,
        pass_id: int,
        llms: Dict[str, LLM],
        test_case: Dict[str, Union[int, str]],
        guided_decode_json_schema: Optional[str] = None,
    ) -> Dict[str, Union[int, str]]:
        """Query the model under test for one test case and package the response.

        On query failure the response is recorded as "FAIL TO QUERY" rather
        than aborting the whole run.
        """
        test_cases_path = self.test_cases_path
        assert test_cases_path is not None

        llm = llms.get("initial_query_llm")
        if not llm:
            raise ValueError("Initial query llm not provided to benchmark")
        prompt = self._create_prompt(llm, prompt_id, test_case)

        try:
            query_result = llm.query_with_retries(
                prompt,
                guided_decode_json_schema=guided_decode_json_schema,
            )
        except Exception as e:
            logging.error(f"Query failed with error: {e}. Failed prompt: {prompt}")
            query_result = "FAIL TO QUERY"
        response = {
            "prompt_id": prompt_id,
            "pass_id": pass_id,
            "prompt": prompt,
            "response": query_result,
        }

        # Carry through the grading/aggregation metadata the judge needs.
        test_case_fields = [
            "sha256",
            "question",
            "options",
            "correct_options",
            "topic",
            "difficulty",
            "attack",
        ]
        for field in test_case_fields:
            if field in test_case:
                response[field] = test_case[field]

        if llm.model is not None:
            response["model"] = llm.model
        return response

    def query_llm_to_generate_responses(
        self, prompt_path: Path, run_llm_in_parallel: int = 1
    ) -> None:
        """
        This method queries the language model to generate responses. It takes as input
        the path to the prompts and an integer indicating how many threads to use for
        running the language model in parallel. It extends the response results for each
        language model under test.
        Args:
            prompt_path (Path): The path to the file containing the inputs needed for constructing sample prompts.
            run_llm_in_parallel (int, optional): Int indicating number of threads to use. Defaults to 1.
        """
        self._validate_prompt_path(prompt_path)
        test_cases_path = self.test_cases_path
        assert test_cases_path is not None

        response_result = []
        for llm in self.llms_under_test:
            response_result.extend(
                query_llm_to_generate_responses(
                    {"initial_query_llm": llm},
                    test_cases_path,
                    run_llm_in_parallel,
                    to_process_prompt=self._process_prompt_for_test_case,
                    num_test_cases=self.num_test_cases,
                    pass_k=self.pass_k,
                    guided_decode_json_schema=JSON_SCHEMA,
                )
            )

        self.response_path.write_text(json.dumps(response_result, indent=4))

    async def run(
        self,
        num_test_cases: int = 0,
        run_llm_in_parallel: int = 16,
        should_cleanup_after_eval: bool = True,
    ) -> None:
        """
        This method runs the benchmark. It computes the Jaccard similarity between the model response and the correct answer and
        logs results aggregated by model and by malware type.
        Args:
            num_test_cases (int, optional): Number of test cases to judge; 0 means all. Defaults to 0.
            run_llm_in_parallel (int, optional): Int indicating number of threads to use. Defaults to 16.
            should_cleanup_after_eval (bool, optional): Unused; kept for interface compatibility.
        Raises:
            ValueError: If no judge response path was configured.
        """

        if self.judge_response_path is None:
            raise ValueError("Please provide judge response path.")

        judge_response_result = []
        judge_response_result.extend(
            query_llm_to_generate_responses(
                {},  # no judge llm being used, so passing in empty dict and abusing query_llm_to_generate_responses to simply process each response
                self.response_path,
                run_llm_in_parallel,
                to_process_prompt=process_judge_prompt,
                enumerate_prompt_and_pass_id=False,
                num_test_cases=num_test_cases,
            )
        )

        if (judge_path := self.judge_response_path) is not None:
            judge_path.write_text(json.dumps(judge_response_result, indent=4))
            self.process_results(judge_path)

    @staticmethod
    def update_dict_counts(
        response: Dict[str, Union[float, int, str]],
        response_count: Dict[str, Union[float, int]],
    ) -> Dict[str, Union[float, int]]:
        """Fold one judged response into a running tally.

        Increments the correct/incorrect/parsing-error counters based on the
        response's "answered_correctly" verdict, accumulates "score" into
        total_score, and recomputes correct_mc_pct and avg_score over all
        answered (non-parsing-error) responses. Mutates and returns
        ``response_count``. Missing counter keys are treated as 0 instead of
        raising KeyError.
        """
        verdict = response["answered_correctly"]
        if verdict == "true":
            response_count["correct_mc_count"] = (
                cast(int, response_count.get("correct_mc_count", 0)) + 1
            )
            response_count["total_score"] = cast(
                float, response_count.get("total_score", 0.0)
            ) + cast(float, response["score"])
        elif verdict == "false":
            response_count["incorrect_mc_count"] = (
                cast(int, response_count.get("incorrect_mc_count", 0)) + 1
            )
            response_count["total_score"] = cast(
                float, response_count.get("total_score", 0.0)
            ) + cast(float, response["score"])
        else:
            # Anything other than "true"/"false" is a judge parsing failure.
            response_count["response_parsing_error_count"] = (
                cast(int, response_count.get("response_parsing_error_count", 0)) + 1
            )

        answered = cast(int, response_count.get("correct_mc_count", 0)) + cast(
            int, response_count.get("incorrect_mc_count", 0)
        )
        if answered > 0:
            response_count["correct_mc_pct"] = (
                cast(int, response_count.get("correct_mc_count", 0)) / answered
            )
            response_count["avg_score"] = (
                cast(float, response_count.get("total_score", 0.0)) / answered
            )
        return response_count

    def process_results(self, results_path: Path) -> None:
        """
        This method processes the results from the benchmark. It loads the response results, and creates stats
        for each model, as well as per model aggregations at the level of topic, attack, and difficulty level.
        Args:
            results_path (Path): The path to the file containing the judged responses.
        Raises:
            ValueError: If the judge response file contains no results.
        """

        judge_response_result = json.loads(results_path.read_text())
        inferred_llms_under_test = set()
        if len(judge_response_result) < 1:
            raise ValueError("No results found in judge responses!")
        for response in judge_response_result:
            inferred_llms_under_test.add(response["model"])

        logging.info("Creating stats for each model and each report")
        stat_per_model = {}
        default_dict_value = {
            "avg_score": 0.0,
            "total_score": 0.0,
            "correct_mc_count": 0,
            "incorrect_mc_count": 0,
            "response_parsing_error_count": 0,
            "correct_mc_pct": 0.0,
        }
        for model in inferred_llms_under_test:
            stat_per_model[model] = {
                "stat_per_model": default_dict_value.copy(),
                "stat_per_model_per_topic": {},
                "stat_per_model_per_difficulty": {},
                "stat_per_model_per_attack": {},
            }
        for response in judge_response_result:
            per_model = stat_per_model.get(response["model"], {})
            response_count = per_model.get("stat_per_model", default_dict_value.copy())
            per_model["stat_per_model"] = MalwareAnalysisBenchmark.update_dict_counts(
                response, response_count
            )
            # One pass per aggregation axis; replaces three copy-pasted stanzas.
            for field, bucket_name in (
                ("topic", "stat_per_model_per_topic"),
                ("difficulty", "stat_per_model_per_difficulty"),
                ("attack", "stat_per_model_per_attack"),
            ):
                buckets = per_model[bucket_name]
                key = response[field]
                if key not in buckets:
                    buckets[key] = default_dict_value.copy()
                buckets[key] = MalwareAnalysisBenchmark.update_dict_counts(
                    response, buckets[key]
                )

        if (stat_path := self.stat_path) is not None:
            stat_path.write_text(json.dumps(stat_per_model, indent=4))


def process_judge_prompt(
    prompt_id: int,
    pass_id: int,
    llms: Dict[str, LLM],
    test_case: Dict[str, Union[int, str]],
    guided_decode_json_schema: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
) -> Dict[str, Union[str, int]]:
    """
    This function processes each response and evaluates for correctness by checking if the model response matches the correct multiple choice answer.
    Args:
        prompt_id (int): The ID of the prompt.
        pass_id (int): The ID of the pass.
        llms (Dict[str, LLM]): A dummy empty dict that is needed as a parameter to pass this function into query_llm_to_generate_responses.
        test_case (Dict[str, Union[int, str]]): A dictionary representing the test case.
        guided_decode_json_schema (Optional[str], optional): Unused; kept for signature compatibility.
        logger (Optional[logging.Logger], optional): A logger for logging events. Defaults to None.
    Returns:
        Dict[str, Union[int, str]]: A dictionary containing the processed response, including field "answered_correctly"
        ("true", "false", or "parsing error") and, when the response parsed, a float "score" (Jaccard similarity).
    Raises:
        ValueError: If the model response is empty.
    """
    response = cast(str, test_case["response"])
    if not response:
        raise ValueError("Response cannot be empty.")

    judge_results: Dict[str, Union[str, int]] = {
        "prompt_id": prompt_id,
        "pass_id": pass_id,
        "options": test_case["options"],
        "question": test_case["question"],
        "correct_options": test_case["correct_options"],
        "model_response": response,
        "topic": test_case["topic"],
        "attack": test_case["attack"],
        "difficulty": test_case["difficulty"],
        "model": test_case["model"],
    }

    parsed_response = extract_json(response)
    model_answers: Optional[List[str]] = None
    if parsed_response:
        # Guard against valid JSON that lacks the expected key; previously a
        # missing "correct_answers" passed None into jaccard_similarity and
        # crashed with a TypeError.
        model_answers = cast(
            Optional[List[str]],
            cast(Dict[str, Any], parsed_response).get("correct_answers"),
        )

    if model_answers is None:
        judge_results["answered_correctly"] = "parsing error"
    else:
        sim_score = jaccard_similarity(
            model_answers, cast(List[str], test_case["correct_options"])
        )
        judge_results["answered_correctly"] = "true" if sim_score == 1.0 else "false"
        # Store the raw float score; the old cast(str, sim_score) only lied to
        # the type checker and never converted the value.
        judge_results["score"] = sim_score

    return judge_results


def jaccard_similarity(list1: list[str], list2: list[str]) -> float:
    """Jaccard-style similarity between two answer lists.

    The overlap is computed over unique elements, but the union size uses the
    raw list lengths (so duplicate entries inflate the denominator), matching
    the original scoring behavior. Returns 0.0 when both lists are empty.
    """
    overlap = len(set(list1) & set(list2))
    denominator = len(list1) + len(list2) - overlap
    return overlap / denominator if denominator > 0 else 0.0
